{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read and Decode BBCI Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This tutorial shows how to read and decode BBCI data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up logging to see outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import sys\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
    "                     level=logging.DEBUG, stream=sys.stdout)\n",
    "log = logging.getLogger()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and preprocess data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First set the filename and the sensors you want to load. If you set\n",
    "\n",
    "```python\n",
    "load_sensor_names=None\n",
    "```\n",
    "\n",
    "or just remove the parameter from the function call, all sensors will be loaded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from braindecode.datasets.bbci import BBCIDataset\n",
    "train_filename = '/home/schirrmr/data/BBCI-without-last-runs/BhNoMoSc1S001R01_ds10_1-12.BBCI.mat'\n",
    "cnt = BBCIDataset(train_filename, load_sensor_names=['C3', 'CPz', 'C4']).load()"
   ]
  },
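  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the sketch below mimics the behavior of `load_sensor_names` on a plain channels x time NumPy array. Passing a list keeps only the named channels; passing `None` keeps all of them. The helper `select_channels` and the toy channel names are hypothetical and not part of Braindecode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def select_channels(data, ch_names, wanted=None):\n",
    "    # Hypothetical helper, not Braindecode API: keep only the rows (channels)\n",
    "    # named in `wanted`; wanted=None keeps all channels\n",
    "    if wanted is None:\n",
    "        return data, list(ch_names)\n",
    "    idx = [ch_names.index(name) for name in wanted]\n",
    "    return data[idx], list(wanted)\n",
    "\n",
    "# Toy data: 5 channels x 4 samples\n",
    "demo_names = ['C3', 'Cz', 'CPz', 'C4', 'Pz']\n",
    "demo_data = np.arange(20, dtype=float).reshape(5, 4)\n",
    "sel_data, sel_names = select_channels(demo_data, demo_names, ['C3', 'CPz', 'C4'])\n",
    "print(sel_names, sel_data.shape)  # ['C3', 'CPz', 'C4'] (3, 4)"
   ]
  },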
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing on continuous data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First remove the stimulus channel, then apply any preprocessing you like. Braindecode itself provides only a few preprocessing functions, such as `resample_cnt`. However, you can apply any function to the channels x time data matrix of the MNE raw object (`cnt` in the code) by calling `mne_apply` with two arguments:\n",
    "\n",
    "1. your function (2d-array -> 2d-array), which transforms the channels x time data array\n",
    "2. the MNE raw object itself"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from braindecode.mne_ext.signalproc import resample_cnt, mne_apply\n",
    "from braindecode.datautil.signalproc import exponential_running_standardize\n",
    "# Remove stimulus channel\n",
    "cnt = cnt.drop_channels(['STI 014'])\n",
    "# Resample to 250 Hz\n",
    "cnt = resample_cnt(cnt, 250)\n",
    "# mne_apply applies the function to the data (a 2d numpy array).\n",
    "# The data has to be transposed back and forth, since\n",
    "# exponential_running_standardize expects time x chans order,\n",
    "# while the mne object has chans x time order\n",
    "cnt = mne_apply(lambda a: exponential_running_standardize(\n",
    "    a.T, init_block_size=1000, factor_new=0.001, eps=1e-4).T,\n",
    "    cnt)"
   ]
  },
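  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To clarify what the standardization step does, here is a simplified NumPy sketch of exponential running standardization. An exponentially weighted running mean and variance are updated per sample, with weight `factor_new` for the newest sample. Each sample is standardized with the current estimates, and the first `init_block_size` samples are standardized with their own mean and standard deviation. This is an illustration, not Braindecode's exact `exponential_running_standardize`, whose initialization and averaging details may differ. The function name `running_standardize_sketch` is hypothetical. Like the Braindecode function, the sketch expects time x channels data, hence the transposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def running_standardize_sketch(data, factor_new=0.001, init_block_size=1000, eps=1e-4):\n",
    "    # data: time x channels. Simplified recursive variant, not Braindecode's code.\n",
    "    out = np.empty(data.shape, dtype=float)\n",
    "    mean = data[0].astype(float)\n",
    "    var = np.zeros(data.shape[1])\n",
    "    for t in range(data.shape[0]):\n",
    "        # Exponentially weighted updates of running mean and variance\n",
    "        mean = (1 - factor_new) * mean + factor_new * data[t]\n",
    "        demeaned = data[t] - mean\n",
    "        var = (1 - factor_new) * var + factor_new * demeaned ** 2\n",
    "        # eps avoids division by (near) zero\n",
    "        out[t] = demeaned / np.maximum(eps, np.sqrt(var))\n",
    "    # Standardize the first block with its own per-channel mean and std\n",
    "    block = data[:init_block_size]\n",
    "    out[:init_block_size] = (block - block.mean(axis=0)) / np.maximum(eps, block.std(axis=0))\n",
    "    return out\n",
    "\n",
    "# Toy continuous data: 3 channels x 5000 samples with mean 5 and std 3\n",
    "demo_rng = np.random.RandomState(0)\n",
    "demo_cnt_data = 5.0 + 3.0 * demo_rng.randn(3, 5000)\n",
    "# Transpose to time x channels for the function and back afterwards\n",
    "demo_standardized = running_standardize_sketch(demo_cnt_data.T).T\n",
    "print(demo_standardized.shape)  # (3, 5000)"
   ]
  },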
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Transform to epoched dataset "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Braindecode supplies the `create_signal_target_from_raw_mne` function, which transforms the MNE raw object into a `SignalAndTarget` object for use in Braindecode.\n",
    "`name_to_code` should be an `OrderedDict` that maps each class name to either a single marker code or a list of marker codes for that class. `segment_ival_ms` gives the start and end of each trial window in milliseconds, relative to the trial marker."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne\n",
    "from collections import OrderedDict\n",
    "# A class can also be mapped to a list of marker codes if it has several\n",
    "name_to_code = OrderedDict([('Right', 1), ('Left', 2), ('Rest', 3), ('Feet', 4)])\n",
    "# Trial window from 500 ms before to 4000 ms after each marker\n",
    "segment_ival_ms = [-500, 4000]\n",
    "\n",
    "train_set = create_signal_target_from_raw_mne(cnt, name_to_code, segment_ival_ms)"
   ]
  },
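  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, epoching cuts a fixed window around each marker and maps each marker code to a class index. The NumPy sketch below illustrates this under stated assumptions. It uses a 250 Hz sampling rate, as after `resample_cnt` above. The window includes its start sample and excludes its end sample. Class indices are given by the position of the class in `name_to_code`. The helper `cut_trials` and the toy marker positions are hypothetical; this is not Braindecode's implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import OrderedDict\n",
    "import numpy as np\n",
    "\n",
    "def cut_trials(data, marker_samples, marker_codes, name_to_code, ival_ms, fs):\n",
    "    # Hypothetical helper, not Braindecode's implementation.\n",
    "    # data: channels x time; marker_samples: sample index of each marker.\n",
    "    # Assumption: class index = position of the class in name_to_code\n",
    "    code_to_class = {}\n",
    "    for i_class, codes in enumerate(name_to_code.values()):\n",
    "        codes = codes if isinstance(codes, (list, tuple)) else [codes]\n",
    "        for code in codes:\n",
    "            code_to_class[code] = i_class\n",
    "    # Convert the window from ms to samples (start inclusive, stop exclusive)\n",
    "    start_offset = int(round(ival_ms[0] * fs / 1000.0))\n",
    "    stop_offset = int(round(ival_ms[1] * fs / 1000.0))\n",
    "    X, y = [], []\n",
    "    for sample, code in zip(marker_samples, marker_codes):\n",
    "        if code not in code_to_class:\n",
    "            continue  # skip markers that belong to no class\n",
    "        X.append(data[:, sample + start_offset:sample + stop_offset])\n",
    "        y.append(code_to_class[code])\n",
    "    return np.array(X), np.array(y)\n",
    "\n",
    "# Toy example: 'Feet' is mapped to two marker codes, code 9 belongs to no class\n",
    "demo_name_to_code = OrderedDict([('Right', 1), ('Left', 2), ('Rest', 3), ('Feet', [4, 5])])\n",
    "demo_data = np.zeros((3, 10000))  # 3 channels x 40 s at 250 Hz\n",
    "demo_X, demo_y = cut_trials(demo_data, [1000, 3000, 5000, 7000], [2, 4, 9, 5],\n",
    "                            demo_name_to_code, [-500, 4000], fs=250)\n",
    "print(demo_X.shape, demo_y)  # (3, 3, 1125) [1 3 3]"
   ]
  },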
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Same for the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "test_filename = '/home/schirrmr/data/BBCI-only-last-runs/BhNoMoSc1S001R13_ds10_1-2BBCI.mat'\n",
    "cnt = BBCIDataset(test_filename, load_sensor_names=['C3', 'CPz', 'C4']).load()\n",
    "# Remove stimulus channel, resample and standardize exactly as for the training set\n",
    "cnt = cnt.drop_channels(['STI 014'])\n",
    "cnt = resample_cnt(cnt, 250)\n",
    "cnt = mne_apply(lambda a: exponential_running_standardize(\n",
    "    a.T, init_block_size=1000, factor_new=0.001, eps=1e-4).T,\n",
    "    cnt)\n",
    "test_set = create_signal_target_from_raw_mne(cnt, name_to_code, segment_ival_ms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "In case of start and stop markers, provide a `name_to_stop_codes` dictionary (structured like `name_to_code` for the start codes) as the final argument to `create_signal_target_from_raw_mne`. See the [Read and Decode BBCI Data with Start-Stop-Markers Tutorial](BBCI_Data_Start_Stop.html).\n",
    "\n",
    "\n",
    "</div>"
   ]
  },
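  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With start and stop markers, trial lengths can vary from trial to trial. The hypothetical helper below illustrates the idea by pairing each start marker with the next following stop marker. It does not show how Braindecode implements this, and the marker codes and sample positions are made up."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pair_start_stop(marker_samples, marker_codes, start_codes, stop_codes):\n",
    "    # Hypothetical helper: pair each start marker with the next stop marker\n",
    "    trials = []\n",
    "    open_start = None\n",
    "    for sample, code in zip(marker_samples, marker_codes):\n",
    "        if code in start_codes:\n",
    "            open_start = sample  # a new start replaces an unmatched earlier one\n",
    "        elif code in stop_codes and open_start is not None:\n",
    "            trials.append((open_start, sample))\n",
    "            open_start = None\n",
    "    return trials\n",
    "\n",
    "# Toy markers: start codes 1 and 2, stop code 20\n",
    "demo_trials = pair_start_stop([100, 600, 900, 1700], [1, 20, 2, 20],\n",
    "                              start_codes={1, 2}, stop_codes={20})\n",
    "print(demo_trials)  # [(100, 600), (900, 1700)]\n",
    "print([stop - start for start, stop in demo_trials])  # [500, 800]"
   ]
  },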
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Split off a validation set: 80% of the training trials are kept for training, and the remaining 20% are used for validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.splitters import split_into_two_sets\n",
    "\n",
    "train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=0.8)\n"
   ]
  },
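  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of what such a split amounts to. It assumes the trials are split in their original order (no shuffling) and that the size of the first set is rounded to a whole number of trials. The function `split_two_sketch` is hypothetical and works on plain arrays rather than on a `SignalAndTarget` object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def split_two_sketch(X, y, first_set_fraction=0.8):\n",
    "    # Hypothetical helper: sequential split along the trial axis, no shuffling\n",
    "    n_first = int(round(len(X) * first_set_fraction))\n",
    "    return (X[:n_first], y[:n_first]), (X[n_first:], y[n_first:])\n",
    "\n",
    "# Toy data: 10 trials x 3 channels x 1125 samples\n",
    "demo_X = np.zeros((10, 3, 1125))\n",
    "demo_y = np.arange(10) % 4\n",
    "(demo_train_X, demo_train_y), (demo_valid_X, demo_valid_y) = split_two_sketch(demo_X, demo_y)\n",
    "print(len(demo_train_X), len(demo_valid_X))  # 8 2"
   ]
  },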
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "import torch as th\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "\n",
    "# Use the GPU if one is available; set cuda = False to force the CPU\n",
    "cuda = th.cuda.is_available()\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "\n",
    "# For trialwise decoding, the input length is the full trial length in samples\n",
    "input_time_length = train_set.X.shape[2]\n",
    "in_chans = train_set.X.shape[1]  # 3 sensors: C3, CPz, C4\n",
    "n_classes = 4\n",
    "# final_conv_length determines the receptive field of the ConvNet;\n",
    "# 'auto' chooses it so that the network outputs one prediction per trial\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes, input_time_length=input_time_length,\n",
    "                        final_conv_length='auto').create_network()\n",
    "\n",
    "if cuda:\n",
    "    model.cuda()"
   ]
  },
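  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of what `final_conv_length='auto'` means, the sketch below traces the length of the time axis through an unpadded temporal convolution and mean pooling, and returns what is left for the final convolution. It assumes ShallowFBCSPNet's default hyperparameters (temporal filter length 25, pooling length 75, pooling stride 15) and a trial length of 1125 samples. Check `train_set.X.shape[2]` and the model signature of your Braindecode version. The function names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def valid_out_len(n_in, kernel, stride=1):\n",
    "    # Output length of an unpadded convolution or pooling along the time axis\n",
    "    return (n_in - kernel) // stride + 1\n",
    "\n",
    "def auto_final_conv_length_sketch(input_time_length, filter_time_length=25,\n",
    "                                  pool_time_length=75, pool_time_stride=15):\n",
    "    # Assumed ShallowFBCSPNet defaults; the spatial convolution keeps the time length\n",
    "    n_times = valid_out_len(input_time_length, filter_time_length)  # temporal convolution\n",
    "    n_times = valid_out_len(n_times, pool_time_length, pool_time_stride)  # mean pooling\n",
    "    return n_times\n",
    "\n",
    "# 1125 samples = 4.5 s at 250 Hz (the [-500, 4000] ms window)\n",
    "demo_final_len = auto_final_conv_length_sketch(1125)\n",
    "print(demo_final_len)  # 69\n",
    "# A final convolution spanning all remaining samples leaves one output per trial\n",
    "print(valid_out_len(demo_final_len, demo_final_len))  # 1"
   ]
  },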
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the optimizer and iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch import optim\n",
    "from braindecode.datautil.iterators import BalancedBatchSizeIterator\n",
    "\n",
    "# Adam with PyTorch's default hyperparameters (learning rate 1e-3)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "# Iterates over the trials in batches of approximately 32 trials\n",
    "iterator = BalancedBatchSizeIterator(batch_size=32)"
   ]
  },
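  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A balanced batch size iterator avoids a tiny leftover batch in each epoch. For 230 trials and `batch_size=32`, naive batching would give seven batches of 32 and a final batch of only 6 trials. The sketch below shows one way to balance the batches. It assumes the number of batches is obtained by rounding `n_trials / batch_size` and that the remaining trials are spread so that batch sizes differ by at most one. The helper names are hypothetical, and this is not necessarily Braindecode's exact logic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def balanced_batch_sizes(n_trials, batch_size):\n",
    "    # Number of batches by rounding, then spread the remainder (assumed behavior)\n",
    "    n_batches = max(1, int(np.round(n_trials / float(batch_size))))\n",
    "    base, n_extra = divmod(n_trials, n_batches)\n",
    "    return [base + 1] * n_extra + [base] * (n_batches - n_extra)\n",
    "\n",
    "def balanced_batches(n_trials, batch_size, rng):\n",
    "    # Shuffle the trial indices, then cut them into the balanced batch sizes\n",
    "    order = rng.permutation(n_trials)\n",
    "    bounds = np.cumsum([0] + balanced_batch_sizes(n_trials, batch_size))\n",
    "    return [order[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]\n",
    "\n",
    "print(balanced_batch_sizes(230, 32))  # [33, 33, 33, 33, 33, 33, 32]\n",
    "demo_batches = balanced_batches(230, 32, np.random.RandomState(0))"
   ]
  },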
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up monitors, loss function and stop criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.experiments.experiment import Experiment\n",
    "from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, MisclassMonitor\n",
    "from braindecode.experiments.stopcriteria import MaxEpochs\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# ShallowFBCSPNet outputs log-probabilities, so use the negative log-likelihood loss\n",
    "loss_function = F.nll_loss\n",
    "\n",
    "model_constraint = None\n",
    "monitors = [LossMonitor(), MisclassMonitor(col_suffix='misclass'),\n",
    "            RuntimeMonitor()]\n",
    "# Train for 20 epochs\n",
    "stop_criterion = MaxEpochs(20)\n",
    "# remember_best_column: keep track of the epoch with the lowest validation misclassification\n",
    "# run_after_early_stop=True: afterwards, continue training on training and validation set combined\n",
    "exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint,\n",
    "                 monitors, stop_criterion, remember_best_column='valid_misclass',\n",
    "                 run_after_early_stop=True, batch_modifier=None, cuda=cuda)"
   ]
  },
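  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the network ends in a log-softmax, `F.nll_loss` on its outputs equals the cross-entropy of the predicted class probabilities. The NumPy sketch below reproduces that computation, averaging over trials as `F.nll_loss` does by default. It also computes the misclassification rate (the fraction of trials whose highest-scoring class is wrong) that the misclass monitor reports. The function names and toy scores are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def log_softmax(scores):\n",
    "    # Numerically stable log-softmax over the class axis\n",
    "    shifted = scores - scores.max(axis=1, keepdims=True)\n",
    "    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))\n",
    "\n",
    "def nll_loss_sketch(log_probs, targets):\n",
    "    # Mean negative log-probability of the true class of each trial\n",
    "    return -log_probs[np.arange(len(targets)), targets].mean()\n",
    "\n",
    "def misclass_rate(log_probs, targets):\n",
    "    # Fraction of trials whose highest-scoring class is not the true class\n",
    "    return np.mean(log_probs.argmax(axis=1) != targets)\n",
    "\n",
    "# Toy scores for 3 trials and 4 classes\n",
    "demo_scores = np.array([[2.0, 0.5, 0.1, 0.1],\n",
    "                        [0.2, 0.1, 3.0, 0.0],\n",
    "                        [1.0, 1.5, 0.2, 0.3]])\n",
    "demo_targets = np.array([0, 2, 0])\n",
    "demo_log_probs = log_softmax(demo_scores)\n",
    "print(nll_loss_sketch(demo_log_probs, demo_targets))\n",
    "print(misclass_rate(demo_log_probs, demo_targets))  # 0.3333333333333333 (third trial is wrong)"
   ]
  },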
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "exp.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We arrive at around 26% test misclassification; the exact value can vary slightly between runs and machines."
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Edit Metadata",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
